import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Both start out unset; main() assigns them through a `global` statement.
model = None
tokenizer = None

# --------------------------------------------------------------------------------------
# Embedding model configuration
# --------------------------------------------------------------------------------------
# Model name passed to the embeddings endpoint by _get_embeddings().
EMBED_MODEL = "text-embedding-3-small"

# File name for global diagnostics (created in output_dir later)
EMBED_DIAG_FILENAME = "embedding_diagnostics.txt"

# Azure OpenAI configuration (copied from gpt_static_thisruns.py).
# "api_key" is only a placeholder fallback: _get_embeddings() reads the
# OPENAI_API_KEY environment variable first.
api_version = "2024-02-15-preview"
config_dict: Dict[str, str] = dict(
    api_key="YOUR_OPENAI_API_KEY",
    api_version=api_version,
    azure_endpoint="https://your-azure-openai-endpoint/",
)

def _get_embeddings(
    texts: List[str],
    batch_size: int = 96,
    diag_path: Optional[str] = None,
    max_retries: int = 5,
) -> List[Optional[np.ndarray]]:
    """Embed ``texts`` in batches through the Azure OpenAI embeddings endpoint.

    Args:
        texts: Strings to embed.
        batch_size: Number of texts sent per request.
        diag_path: Optional file that error diagnostics are appended to.
        max_retries: Number of retries per batch after the first failed
            attempt (so a batch is tried at most ``max_retries + 1`` times).

    Returns:
        A list the same length as ``texts``. Position ``i`` holds a float32
        vector for ``texts[i]``, or ``None`` when its batch could not be
        embedded (or the client could not be created at all).
    """
    # Pre-allocate output so ordering is preserved even with failures.
    results: List[Optional[np.ndarray]] = [None] * len(texts)
    if not texts:
        # Nothing to embed: skip building a client entirely.
        return results

    def _log(message: str) -> None:
        """Append one diagnostic line when a diagnostics path was given."""
        if diag_path:
            with open(diag_path, "a", encoding="utf-8") as f:
                f.write(message + "\n")

    # Build Azure client once (outside the loop for efficiency)
    try:
        from openai import AzureOpenAI  # local import to avoid dependency when not needed
        client = AzureOpenAI(
            api_key=os.getenv("OPENAI_API_KEY", config_dict["api_key"]),
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    except Exception as e:
        _log(f"[CLIENT-INIT-ERROR] Failed to create AzureOpenAI client: {e}")
        return results  # all None

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]

        # Retry loop for this chunk only
        succeeded = False
        attempt = 0
        while attempt <= max_retries:
            try:
                resp = client.embeddings.create(model=EMBED_MODEL, input=chunk)
                # Order by the per-item index so outputs line up with inputs.
                ordered = sorted(resp.data, key=lambda x: x.index)
                for i, d in enumerate(ordered):
                    results[start + i] = np.array(d.embedding, dtype=np.float32)
                succeeded = True
                break  # success, move to next batch

            except Exception as e:
                attempt += 1
                # Default back-off grows linearly; prefer the server's
                # "retry after N" hint when the error message carries one.
                # (The previous pattern used "\\d" inside a raw string, which
                # matches a literal backslash and therefore never fired.)
                wait_secs = 5 * attempt
                m = re.search(r"retry after (\d+)", str(e).lower())
                if m:
                    wait_secs = int(m.group(1))

                _log(
                    f"[EMBEDDING-ERROR] Batch {start}-{start+len(chunk)-1} attempt {attempt}/{max_retries}: {e}. Waiting {wait_secs}s."
                )

                if attempt > max_retries:
                    # Give up on this batch – leave None placeholders.
                    break

                time.sleep(wait_secs)

        # Respect base rate-limit between *successful* calls only.
        if succeeded:
            time.sleep(1)  # more conservative than 0.1s previously

    return results


def _cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Return the cosine similarity of ``u`` and ``v`` as a Python float.

    A small constant (1e-8) is added to the product of the norms so that a
    zero vector yields 0.0 instead of a division by zero.
    """
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    similarity = np.dot(u, v) / (norm_product + 1e-8)
    return float(similarity)

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(user_prompt: str, model, tokenizer, args) -> str:
    """Generate a single-turn chat completion for ``user_prompt``.

    The prompt is sent as the only (user) message, prefixed with the literal
    "/no_think" marker, and the chat template is applied with
    ``enable_thinking=False``. Sampling is used (temperature 0.85, min_p 0.1)
    with at most 1200 new tokens.

    Args:
        user_prompt: Text of the user turn.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        args: Accepted for call-site compatibility; not read in this function.

    Returns:
        The decoded completion (prompt tokens removed, special tokens skipped),
        stripped of surrounding whitespace.
    """
    # NOTE(review): "/no_think" is concatenated with no separator before the
    # prompt text; confirm this is the intended form of the marker.
    chat = [{"role": "user", "content": "/no_think" + user_prompt}]

    template_options = dict(
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,
    )
    # Move the encoded prompt onto whatever device the model lives on.
    input_ids = tokenizer.apply_chat_template(chat, **template_options).to(model.device)

    sampling_options = dict(
        max_new_tokens=1200,
        temperature=0.85,
        use_cache=True,
        do_sample=True,
        min_p=0.1,
    )
    with torch.no_grad():
        generated = model.generate(input_ids=input_ids, **sampling_options)

    # The output sequence begins with the prompt; keep only what follows it.
    prompt_len = len(input_ids[0])
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tweet evaluation script.

    Only ``--dataset_paths`` is required; every other option has a default.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Run static evaluation for tweet engagement.")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--start", dict(type=int, default=0, help="Start index (inclusive) of the slice.")),
        ("--end", dict(type=int, default=None, help="End index (inclusive) of the slice.")),
        ("--output_dir", dict(type=str, default="static_folder", help="Directory to write JSON results.")),
        ("--gpu_id", dict(type=int, default=0, help="GPU ID to use for inference.")),
        ("--dataset_paths", dict(type=str, required=True, help="Comma-separated list of *.jsonl datasets to evaluate.")),
        ("--max_examples", dict(type=int, default=None, help="(Optional) truncate dataset to this many examples – useful for quick smoke tests.")),
        ("--similarity_json", dict(type=str, default=None, help="Path to JSON with pre-computed nearest neighbours.")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic (chunk mode)
# --------------------------------------------------------------------------------------

def main() -> None:
    """Script entry point: parse options, load the model once, run the evaluation.

    The tokenizer and model are stored in the module-level ``tokenizer`` and
    ``model`` globals, and are only loaded when either of them is still unset.
    """
    global model, tokenizer
    cli_args = parse_args()

    # -----------------------------------------------------------------------------
    # Load Model ONCE (before any evaluation path)
    # -----------------------------------------------------------------------------
    already_loaded = model is not None and tokenizer is not None
    if not already_loaded:
        checkpoint = "Qwen/Qwen3-32B"  # change to meta-llama/Llama-3.3-70B-Instruct for LlaMA
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype="auto",
            device_map="auto",
        )

    # Tweet evaluation is the only task, so it is run directly with whatever
    # similarity map (possibly None) the options resolve to.
    run_tweet_evaluation(cli_args, sim_map=_get_similarity_map(cli_args))

def _get_similarity_map(args):
    """Return the JSON content of ``args.similarity_json``, or ``None``.

    ``None`` is returned when no path was given, and also (after printing a
    warning) when the given path is not an existing file.
    """
    path = args.similarity_json
    if not path:
        return None

    if not os.path.isfile(path):
        print(f"[WARNING] --similarity_json provided but file not found: {path}")
        return None

    with open(path, "r", encoding="utf-8") as handle:
        loaded = json.load(handle)
    print(f"Loaded pre-computed similarity map from {path} (entries: {len(loaded)})")
    return loaded

def _tweet_key(rec, idx):
    """Build the similarity-lookup key for a tweet record.

    The key is simply the record's index rendered as a string; ``rec`` itself
    is not inspected.
    """
    key = str(idx)
    return key

def _extract_brand_and_date(text: str):
    """Pull a brand name and a four-digit year out of ``text``.

    The brand is the token following a case-insensitive "brand:" marker
    (letters, digits, "_" and "-"), lower-cased. The year is the first
    standalone 19xx/20xx number. Either value is "unknown" when not found.

    Returns:
        A ``(brand, year)`` tuple of strings.
    """
    brand = year = "unknown"

    found_brand = re.search(r"brand\s*:\s*([A-Za-z0-9_\-]+)", text, flags=re.IGNORECASE)
    if found_brand:
        brand = found_brand.group(1).lower()

    found_year = re.search(r"\b(19|20)\d{2}\b", text)
    if found_year:
        # group(0) is the whole four-digit match, not just the century prefix.
        year = found_year.group(0)

    return brand, year

def run_tweet_evaluation(args, sim_map=None):
    """Run few-shot high/low-likes prediction over each JSONL dataset.

    For every record, up to five other records from the same (sliced) dataset
    are drawn at random as examples, preferring a draw that mixes "high" and
    "low" ground-truth labels. The module-level ``model``/``tokenizer`` are
    queried through ``verbalize`` and the "Answer:" line is compared with the
    ground truth. Results are written to
    ``<output_dir>/tweet_results_<dataset>_<start>_<end>.json`` after every
    example, and accuracies (overall, per brand, per year) are printed.

    Args:
        args: Namespace providing ``dataset_paths`` (comma-separated paths),
            ``output_dir``, ``max_examples`` and optionally ``start``/``end``
            (inclusive slice bounds applied after loading).
        sim_map: Pre-computed similarity map. When ``None`` the evaluation is
            skipped and the function returns immediately.

    Note:
        After all datasets are processed the function terminates the process
        with ``sys.exit(0)``.
    """
    if sim_map is None:
        # NOTE(review): sim_map is not consulted anywhere below; confirm that
        # gating the whole evaluation on it is still intended.
        print("[WARNING] No similarity map available; skipping tweet evaluation.")
        return

    # Compiled once: the same pattern labels every record in every dataset.
    high_likes_re = re.compile(r"high likes", flags=re.IGNORECASE)

    # Resolve dataset paths
    dset_paths = [p.strip() for p in args.dataset_paths.split(",") if p.strip()]

    overall_out_dir = args.output_dir or "tweet_static_results"
    os.makedirs(overall_out_dir, exist_ok=True)

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records = []
        with open(dpath, "r", encoding="utf-8") as f_in:
            for line_idx, line in enumerate(f_in):
                if args.max_examples and line_idx >= args.max_examples:
                    break
                try:
                    parsed = json.loads(line)
                except Exception:
                    continue  # skip malformed
                # Valid JSON that is not an object cannot be a record; treat
                # it as malformed too instead of failing later on .get().
                if isinstance(parsed, dict):
                    records.append(parsed)

        # --- Apply slicing if --start/--end are provided ---
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        last_idx = len(records) - 1
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = min(end_arg, last_idx) if end_arg is not None else last_idx
        if slice_start > 0 or slice_end < last_idx:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        # The slice bounds are always part of the file name.
        out_path = os.path.join(
            overall_out_dir,
            f"tweet_results_{dataset_name}_{slice_start}_{slice_end}.json",
        )

        correct = 0
        brand_stats = {}
        time_stats = {}
        all_results = []

        # Ground-truth labels are loop-invariant: compute them once instead of
        # re-running the regex over the records for every example.
        is_high = [bool(high_likes_re.search(r.get("response", ""))) for r in records]
        high_ids = [i for i, flag in enumerate(is_high) if flag]
        low_ids = [i for i, flag in enumerate(is_high) if not flag]
        all_indices = list(range(len(records)))

        for idx, rec in enumerate(tqdm(records, desc=dataset_name)):
            prompt_text = rec.get("prompt", "")
            gt_label = "high" if is_high[idx] else "low"

            # --- FEW-SHOT EXAMPLES: Always select 5 random, but ensure mix of high/low ---
            pool = [i for i in all_indices if i != idx]
            n_shots = min(5, len(pool))
            max_attempts = 10
            for _ in range(max_attempts):
                neighbor_ids = random.sample(pool, k=n_shots)
                n_high = sum(is_high[nid] for nid in neighbor_ids)
                if 0 < n_high < len(neighbor_ids):
                    break  # At least one high and one low
            else:
                # Fallback: force at least one high and one low if possible
                highs = [i for i in high_ids if i != idx]
                lows = [i for i in low_ids if i != idx]
                neighbor_ids = []
                if highs:
                    neighbor_ids.append(random.choice(highs))
                if lows:
                    neighbor_ids.append(random.choice(lows))
                rest = [i for i in pool if i not in neighbor_ids]
                neighbor_ids += random.sample(rest, k=min(5 - len(neighbor_ids), len(rest)))
            log_msg = f"[RANDOM_MIXED] Used random mixed neighbors for idx {idx}: {neighbor_ids}"

            # For tweet tasks there is no numeric score, so only the prompt of
            # each example is shown. A missing prompt becomes an empty string,
            # matching how the target record is read.
            example_blocks = [str(records[sid].get("prompt", "")) for sid in neighbor_ids]
            examples_text = "\n---\n".join(example_blocks)

            user_prompt = (
                "Below are five example tweets. After these, you'll see a new tweet. "
                "Predict whether it will receive high or low likes on Twitter. "
                "Return two lines exactly:\nReason: <brief>\nAnswer: [High / Low]\n"
                "Examples:\n" + examples_text +
                "\n---\n" + prompt_text
            )

            resp_text = verbalize(user_prompt, model, tokenizer, args)
            # Extract high/low from the answer line
            match = re.search(r"Answer:\s*(high|low)", resp_text, flags=re.IGNORECASE)
            pred_label = match.group(1).lower() if match else None

            # Accuracy bookkeeping (an unparseable answer counts as incorrect)
            is_correct = pred_label == gt_label
            if is_correct:
                correct += 1

            brand, year = _extract_brand_and_date(prompt_text)

            # Brand stats
            b_stats = brand_stats.setdefault(brand, {"correct": 0, "total": 0})
            b_stats["total"] += 1
            if is_correct:
                b_stats["correct"] += 1

            # Time stats (year)
            t_stats = time_stats.setdefault(year, {"correct": 0, "total": 0})
            t_stats["total"] += 1
            if is_correct:
                t_stats["correct"] += 1

            all_results.append({
                "prompt": prompt_text,
                "ground_truth": gt_label,
                "response": resp_text,
                "predicted_label": pred_label,
                "neighbor_ids": neighbor_ids,
                "neighbor_log": log_msg,
            })

            # Incremental save after every example to avoid data loss.
            # Best-effort: a failed save is reported but does not stop the run.
            try:
                with open(out_path, "w", encoding="utf-8") as f_out_inc:
                    json.dump(all_results, f_out_inc, indent=2)
            except Exception as _e:
                print(f"[WARNING] Incremental save failed: {_e}")

        # — Final save per-dataset results (also covers the zero-record case)
        with open(out_path, "w", encoding="utf-8") as f_out:
            json.dump(all_results, f_out, indent=2)

        # — Report accuracies
        total = len(records)
        overall_acc = correct / total if total else 0.0
        print(f"Overall accuracy for {dataset_name}: {overall_acc:.3f} ({correct}/{total})")

        print("\nAccuracy by brand:")
        for b, st in sorted(brand_stats.items(), key=lambda x: x[0]):
            acc = st["correct"] / st["total"] if st["total"] else 0.0
            print(f"  {b}: {acc:.3f} ({st['correct']}/{st['total']})")

        print("\nAccuracy by year:")
        for y, st in sorted(time_stats.items(), key=lambda x: x[0]):
            acc = st["correct"] / st["total"] if st["total"] else 0.0
            print(f"  {y}: {acc:.3f} ({st['correct']}/{st['total']})")

    print("\n[INFO] Tweet evaluation complete. Exiting.")
    sys.exit(0)

if __name__ == "__main__":
    main()


            

            

        
        
        

        

